from numpy import *
from scipy import *
from visual import *
import time
# Explicit numpy handle, imported last so the star imports above (whose
# relative order decides which `dot`/`sqrt`/... wins) are left untouched.
import numpy as np

      
class PlaneWave:
    """A single plane-wave basis element with a four-component spinor.

    The wave vector is (2*pi / sizeOfLattice) * (modex, modey, 0).  Only the
    two "up" components carry the wave; both "down" components are 0.
    """

    def __init__(self, sizeOfLattice, center, sigma, modex, modey):
        # center and sigma are kept as attributes but not used by the wave.
        self.sizeOfLattice = sizeOfLattice
        self.center = center
        self.sigma = sigma
        self.modex, self.modey = modex, modey

        # Overall amplitude is tonorm * coeff; tonorm is changed by setNorm().
        self.tonorm, self.coeff = 1.0, 1.0

        waveNumber = 2*pi/self.sizeOfLattice
        self.p = waveNumber*vector(self.modex, self.modey, 0.0)
        self.spinor = [0.0, 0.0, 0.0, 0.0]

    def returnSpinor(self, x, y):
        """Return [up1, down1, up2, down2] of this wave at lattice point (x, y)."""
        amplitude = self.tonorm*self.coeff
        phase = e**(-1j*dot(self.p, vector(x, y, 0)))
        upComponent = amplitude*phase
        return [upComponent, 0, upComponent, 0]

    def setNorm(self, tonorm):
        """Replace the normalization factor applied to the amplitude."""
        self.tonorm = tonorm



class Basis:
    """Superposition of sizeOfBasis x sizeOfBasis PlaneWave elements.

    basisElements[modex][modey] is the PlaneWave with mode numbers
    (modex, modey) for every modex, modey in range(sizeOfBasis).  The
    wavefunction at a lattice site is the component-wise sum of all of
    their spinors.
    """

    def __init__(self, sizeOfLattice, sizeOfBasis, center, sigma):
        self.sizeOfLattice = sizeOfLattice
        self.sizeOfBasis = sizeOfBasis
        self.center = center
        self.sigma = sigma

        self.basisElements = [
            [PlaneWave(sizeOfLattice, center, sigma, modex, modey)
             for modey in range(sizeOfBasis)]
            for modex in range(sizeOfBasis)
        ]

    def normalize(self):
        """Rescale every element so the density summed over the lattice is 1.

        The factor is multiplied into each element's current ``tonorm``, so
        calling this again leaves the normalization unchanged (up to
        rounding).  Prints the factor applied.

        Raises:
            ValueError: if the total density is not positive (e.g. all
                amplitudes are zero, the lattice is empty, or it is NaN).
        """
        norm = 0.0
        for x in range(self.sizeOfLattice):
            for y in range(self.sizeOfLattice):
                norm += self.getDensity(x, y)
        if not norm > 0:
            raise ValueError("cannot normalize: total density is %r" % (norm,))

        # Density is quadratic in a common amplitude factor, so scaling every
        # amplitude by 1/sqrt(norm) scales the total density to exactly 1.
        factor = 1.0 / np.sqrt(norm)
        print("Normalization: %s" % (factor,))
        for row in self.basisElements:
            for element in row:
                # Multiply rather than overwrite, so a repeated call does not
                # undo the previous normalization.
                element.setNorm(element.tonorm * factor)

    def getDensity(self, x, y):
        """Return the probability density of the superposition at (x, y).

        Adds the 4-component spinors of all basis elements component-wise
        and returns sum_i |psi_i|**2 as a real number.
        """
        # Accumulate in a complex array: `+=` on a Python list concatenates,
        # which would add per-wave densities and lose all interference.
        total = np.zeros(4, dtype=complex)
        for row in self.basisElements:
            for element in row:
                total += np.asarray(element.returnSpinor(x, y), dtype=complex)
        return np.vdot(total, total).real



class DensityPoint:
    """One VPython sphere whose position, radius and color encode a density."""

    def __init__(self):
        # 1 = shown, 0 = hidden; kept in sync with the sphere's `visible`.
        self.visibility = 1
        # Tiny placeholder radius; setAttributes() assigns the real one.
        self.point = sphere(visible=self.visibility, radius=.001)

    def setAttributes(self, position, size, color):
        """Move, resize and recolor the sphere."""
        self.point.pos = position
        self.point.radius = size
        self.point.color = color

    def toggleVisibility(self):
        """Flip between shown (1) and hidden (0) and apply it to the sphere."""
        self.visibility = (1+self.visibility) % 2
        # Changing only the flag has no visual effect; push it to the sphere.
        self.point.visible = self.visibility



class Lattice:
    """Square grid of DensityPoint spheres sized by a basis' density.

    pointLattice[x][y] is the white sphere for site (x, y), placed at
    vector(x, y, 0.0) with radius sizeOfLattice * basis.getDensity(x, y).
    """

    def __init__(self, basis, sizeOfLattice, sizeOfBasis, center):
        self.sizeOfLattice = sizeOfLattice
        self.sizeOfBasis = sizeOfBasis
        self.center = center

        # Sites are built x-outer, y-inner, one sphere per site.
        self.pointLattice = [
            [self._buildPoint(basis, x, y, sizeOfLattice)
             for y in range(sizeOfLattice)]
            for x in range(sizeOfLattice)
        ]

    def _buildPoint(self, basis, x, y, scale):
        """Create the sphere for site (x, y), sized by scale * density."""
        where = vector(x, y, 0.0)
        radius = scale*basis.getDensity(x, y)
        marker = DensityPoint()
        marker.setAttributes(where, radius, (1.0, 1.0, 1.0))
        return marker


# Simulation parameters.
sizeOfLattice = 24   # the lattice has sizeOfLattice x sizeOfLattice sites
sizeOfBasis = 8      # plane-wave modes 0..sizeOfBasis-1 along each axis
sigma = sizeOfLattice/2.0   # forwarded to every PlaneWave, which only stores it
center = vector(sizeOfLattice/2.0,sizeOfLattice/2.0,0.0)

# Build the plane-wave superposition and rescale its amplitudes so the
# density summed over all lattice sites is 1.
basis = Basis(sizeOfLattice, sizeOfBasis, center, sigma)
basis.normalize()

# Black 600x600 window whose camera is centred on `center`.
# NOTE(review): presumably VPython attaches new objects to the most recently
# created display, so this must stay before Lattice(), which creates the
# spheres — confirm before reordering.
scene = display(title='Density Plot', x=0, y=0, width=600, height=600, center=center, background=(0,0,0))

# One sphere per lattice site, sized by the local density.
lattice = Lattice(basis, sizeOfLattice, sizeOfBasis, center)

